# Copyright (c) 2022-2025, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""
This file contains tests for exporting TransformerEngine models to ONNX.

The purpose of these tests is validation that TE models are converted to their correct ONNX
representation. Toward this end, each test captures the output of a TE module forward pass,
converts the TE module to ONNX, and uses ONNX Runtime (ORT) to execute the ONNX graph and
validate the output against TE's output.

Until FP8 is introduced to the ONNX standard, FP8 QuantizeLinear/DequantizeLinear is implemented
using custom ORT operations.

To run many repetitive tests use pytest-loop:
    $ python3 -m pip install pytest-loop
    $ pytest --loop 1000 tests/pytorch/test_onnx_export.py::test_export_layernorm_recipe

For reproducibility use: torch.manual_seed(0)
"""

import os
import tempfile
import pytest
import warnings
import numpy as np
import onnxruntime as ort
import torch
from torch import nn as nn
from typing import Optional, Union, Tuple, List
from onnxruntime_extensions import PyCustomOpDef, get_library_path, onnx_op
import transformer_engine.pytorch as te
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from transformer_engine.pytorch.export import is_in_onnx_export_mode, te_translation_table
from transformer_engine.pytorch.quantization import FP8GlobalStateManager
from transformer_engine.pytorch.utils import get_default_init_method
import tensorrt as trt

# Global test configuration knobs.

# Enable this to serialize test inputs and outputs to file (as a Polygraphy RunResults instance).
SAVE_TEST_IO = bool(int(os.getenv("NVTE_ONNX_EXPORT_SAVE_TEST_IO", "0")))

# Polygraphy is imported only when serialization is enabled, so it is not needed otherwise.
if SAVE_TEST_IO:
    from polygraphy.json import save_json
    from polygraphy.comparator import RunResults

# The directory where generated ONNX test models are stored.
# Read from the NVTE_TEST_ARTIFACTS_DIR environment variable; when that is unset or empty,
# fall back to a "gen_onnx_models" directory under the system temporary directory.
NVTE_TEST_ARTIFACTS_DIR = os.environ.get("NVTE_TEST_ARTIFACTS_DIR")
NVTE_TEST_ARTIFACTS_DIR = NVTE_TEST_ARTIFACTS_DIR or os.path.join(
    tempfile.gettempdir(), "./gen_onnx_models"
)


# The directory where this file is stored.
TESTS_DIR = os.path.dirname(os.path.abspath(__file__))

# Query FP8 / MXFP8 availability once at import time (each call also returns a reason string).
fp8_available, reason_for_no_fp8 = te.is_fp8_available(return_reason=True)
mxfp8_available, reason_for_no_mxfp8 = te.is_mxfp8_available(return_reason=True)

# Recipes used to parametrize the tests. The order matters: `fp8_recipes[0]` is the default
# recipe of the `_test_export_*` helpers, and `None` (FP8 disabled) is always appended last,
# so index 0 is only `None` when neither MXFP8 nor FP8 is available.
fp8_recipes = []
if mxfp8_available:
    fp8_recipes.append(recipe.MXFP8BlockScaling())
if fp8_available:
    fp8_recipes.append(recipe.DelayedScaling())
    fp8_recipes.append(recipe.Float8CurrentScaling())
fp8_recipes.append(None)

# Activation names; index 0 is the default activation of the LayerNormMLP helper.
supported_activations = ["gelu", "relu", "reglu", "geglu", "swiglu", "clamped_swiglu"]

# Normalization names; index 0 is the default normalization of the helpers.
all_normalizations = ["LayerNorm", "RMSNorm"]


@onnx_op(
    op_type="trt::TRT_FP8QuantizeLinear",
    domain="trt",
    inputs=[PyCustomOpDef.dt_float, PyCustomOpDef.dt_float],
    outputs=[PyCustomOpDef.dt_uint8],
)
def trt_fp8_quantize(t, scale_inv):
    """ONNX Runtime custom op that quantizes `t` to FP8 (E4M3).

    Both arguments arrive as NumPy arrays. `scale_inv` is inverted to obtain the
    quantizer scale. The `_data` buffer of the quantized tensor is returned as a
    NumPy array (declared as uint8 in the op signature).
    """
    inverse_scale = torch.from_numpy(scale_inv).cuda()
    quantizer = te.tensor.float8_tensor.Float8Quantizer(
        scale=1 / inverse_scale,
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    values = torch.from_numpy(t).cuda()
    quantized = quantizer(values)
    return quantized._data.cpu().numpy()


@onnx_op(
    op_type="trt::TRT_FP8DequantizeLinear",
    domain="trt",
    inputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_float],
    outputs=[PyCustomOpDef.dt_float],
)
def trt_fp8_dequantize(t, scale_inv):
    """ONNX Runtime custom op that dequantizes FP8 (E4M3) data back to float.

    `t` holds the raw quantized data and `scale_inv` the inverse scale, both as
    NumPy arrays. The data is wrapped into a quantized tensor (with float32 as its
    nominal dtype) and the dequantized values are returned as a NumPy array.
    """
    inverse_scale = torch.from_numpy(scale_inv).cuda()
    quantizer = te.tensor.float8_tensor.Float8Quantizer(
        scale=1 / inverse_scale,
        amax=torch.zeros([1]).cuda(),
        fp8_dtype=tex.DType.kFloat8E4M3,
    )
    raw = torch.from_numpy(t).cuda()
    wrapped = quantizer.create_tensor_from_data(raw, fake_dtype=torch.float32)
    dequantized = wrapped.dequantize()
    return dequantized.cpu().numpy()


@onnx_op(
    op_type="trt::TRT_MXFP8DynamicQuantize",
    domain="trt",
    inputs=[
        PyCustomOpDef.dt_float,
    ],
    outputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8],
)
def trt_mxfp8_quantize(t):
    """MXFP8 quantization extension for ONNX Runtime.

    `t` is the NumPy array to quantize. Returns a pair of NumPy arrays taken from
    the quantized tensor: its row-wise data and its row-wise inverse scales.
    """
    x = torch.from_numpy(t).cuda()
    q = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
    # Quantize once and read both outputs from the same result, rather than
    # running the quantization kernel a second time just to fetch the scales.
    quantized = q(x)
    return quantized._rowwise_data.cpu().numpy(), quantized._rowwise_scale_inv.cpu().numpy()


@onnx_op(
    op_type="trt::TRT_MXFP8DequantizeLinear",
    domain="trt",
    inputs=[PyCustomOpDef.dt_uint8, PyCustomOpDef.dt_uint8],
    outputs=[PyCustomOpDef.dt_float],
)
def trt_mxfp8_dequantize(t, scale_inv):
    """ONNX Runtime custom op that dequantizes MXFP8 data back to float.

    `t` holds the quantized data and `scale_inv` the matching inverse scales, both
    as NumPy arrays. They are wrapped into a quantized tensor (with float32 as its
    nominal dtype) and the dequantized values are returned as a NumPy array.
    """
    data = torch.from_numpy(t).cuda()
    scales = torch.from_numpy(scale_inv).cuda()
    quantizer = te.tensor.mxfp8_tensor.MXFP8Quantizer(tex.DType.kFloat8E4M3)
    wrapped = quantizer.create_tensor_from_data(data, scales, fake_dtype=torch.float32)
    dequantized = wrapped.dequantize()
    return dequantized.cpu().numpy()


@pytest.fixture()
def seed_default_rng():
    """Reseed the PRNG for test reproducibility.

    Tests that request this fixture start with torch's RNG seeded to a fixed value
    (1234), so the random inputs they draw are the same on every run.
    """
    torch.manual_seed(1234)


@pytest.fixture()
def set_max_seq_len(max_seq_len=128):
    """Set the maximum sequence length that can be used for attention masking.

    The value is published through the NVTE_ONNX_KVCACHE_MAX_SEQ_LEN environment
    variable for the duration of the test. The previous state of the variable is
    restored afterwards so the setting does not leak into tests that did not
    request this fixture.
    """
    env_key = "NVTE_ONNX_KVCACHE_MAX_SEQ_LEN"
    previous = os.environ.get(env_key)
    os.environ[env_key] = f"{max_seq_len}"
    try:
        yield
    finally:
        if previous is None:
            os.environ.pop(env_key, None)
        else:
            os.environ[env_key] = previous


@pytest.fixture(autouse=True)
def reset_global_fp8_state():
    """Reset the global FP8 state after every test.

    The fixture is autouse, so it wraps each test in this module: the test runs at
    the `yield`, then `FP8GlobalStateManager.reset()` is called as teardown.
    """
    yield
    FP8GlobalStateManager.reset()


def do_export(
    model: torch.nn.Module,
    inp: Union[torch.Tensor, List[Optional[torch.Tensor]], Tuple[Optional[torch.Tensor], ...]],
    fname: str,
    fp8_recipe: Optional[recipe.Recipe],
    input_names: Optional[List[str]] = None,
    output_names: Optional[List[str]] = None,
    dynamic_shapes: Optional[dict] = None,
):
    """Export `model` to an ONNX file under NVTE_TEST_ARTIFACTS_DIR.

    Args:
        model: Module to export. It is moved to CUDA and put in eval mode in place.
        inp: A single input tensor, or a list/tuple of inputs. Entries may be `None`
            for optional inputs; their names are dropped from `input_names`.
        fname: File name of the ONNX model, relative to NVTE_TEST_ARTIFACTS_DIR.
        fp8_recipe: Recipe for `te.autocast`; `None` disables FP8 autocast.
        input_names: One name per entry of `inp` (defaults to `["input"]`).
        output_names: Names of the graph outputs (defaults to `["output"]`).
        dynamic_shapes: Forwarded unchanged to `torch.onnx.export`; the callers in
            this file pass a dict mapping argument names to `{dim_index: Dim}`.
    """
    input_names = input_names or ["input"]
    output_names = output_names or ["output"]

    with torch.inference_mode(), te.autocast(
        enabled=fp8_recipe is not None, recipe=fp8_recipe
    ), warnings.catch_warnings():
        warnings.filterwarnings(action="ignore", category=torch.jit.TracerWarning, module=r".*")

        model.cuda().eval()
        os.makedirs(NVTE_TEST_ARTIFACTS_DIR, exist_ok=True)
        fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)

        inps = inp if isinstance(inp, (list, tuple)) else (inp,)
        assert len(inps) == len(input_names)
        # Inputs that are None do not become graph inputs, so drop their names.
        input_names = [name for name, value in zip(input_names, inps) if value is not None]

        model(*inps)  # warm-up run
        # Run one forward pass in ONNX-export mode before the actual export.
        with te.export.onnx_export(True):
            model(*inps)
        with te.export.onnx_export(True):
            torch.onnx.export(
                model,
                inps,
                fname,
                dynamo=True,
                custom_translation_table=te_translation_table,
                verbose=True,
                dynamic_shapes=dynamic_shapes,
                input_names=input_names,
                output_names=output_names,
                # The optimizer does not work with bfloat16 yet - revisit once onnxscript
                # supports bfloat16.
                optimize=inps[0].dtype != torch.bfloat16,
            )


def to_numpy(tensor):
    """Convert a torch tensor to a NumPy array; pass any other object through unchanged.

    bfloat16 tensors are first converted to float32. The tensor is detached and
    moved to the CPU before conversion.
    """
    if not isinstance(tensor, torch.Tensor):
        return tensor
    if tensor.dtype == torch.bfloat16:
        tensor = tensor.type(torch.float32)
    return tensor.detach().cpu().numpy()


def set_layer_scale(module: torch.nn.Module, scale: float, num_gemms: int):
    """Give every forward quantizer of `module` the same FP8 scale.

    FP8 metadata is initialized for `num_gemms` GEMMs first; then each quantizer in
    `module.quantizers["scaling_fwd"]` receives a fresh one-element float32 CUDA
    tensor holding `scale`.
    """
    module.init_fp8_metadata(num_gemms)
    forward_quantizers = module.quantizers["scaling_fwd"]
    for fwd_quantizer in forward_quantizers:
        unit = torch.ones(1, dtype=torch.float32, device="cuda")
        fwd_quantizer.scale = unit * scale


def te_infer(
    model: torch.nn.Module,
    inps: Union[Tuple[torch.Tensor, ...], List[torch.Tensor], torch.Tensor],
    is_fp8: bool,
    fp8_recipe: Optional[recipe.Recipe],
):
    """Transformer Engine forward propagation.

    Args:
        model: Module to run.
        inps: A single input, or a list/tuple of positional inputs. Lists are
            unpacked just like tuples, matching how `do_export` and
            `validate_result` treat their inputs.
        is_fp8: Whether to enable FP8 autocast.
        fp8_recipe: Recipe handed to `te.autocast`.

    Returns:
        The model outputs as a tuple (a single output is wrapped in a 1-tuple).
    """
    args = inps if isinstance(inps, (list, tuple)) else (inps,)
    with torch.inference_mode(), te.autocast(
        enabled=is_fp8, recipe=fp8_recipe
    ), warnings.catch_warnings():
        te_outputs = model(*args)
        if not isinstance(te_outputs, tuple):
            te_outputs = (te_outputs,)
        return te_outputs


def compare_outputs(
    onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
):
    """Check ORT outputs against TE outputs, element by element.

    Each output pair is compared with `np.isclose` using `atol`/`rtol`, with the TE
    output as the reference. When values diverge, up to `max_errors_printed` of them
    are printed along with the maximum error, and a `ValueError` is raised if more
    than `allow_cnt_errors` elements of that output diverge.
    """
    assert len(onnx_outputs) == len(te_outputs)
    for ort_raw, te_raw in zip(onnx_outputs, te_outputs):
        te_arr = to_numpy(te_raw)
        ort_arr = to_numpy(ort_raw)
        # np.isclose: abs(a - b) <= (atol + rtol * abs(b))
        diverging = ~np.isclose(ort_arr, te_arr, atol=atol, rtol=rtol)
        bad_index = diverging.nonzero()
        bad_locs = list(zip(*bad_index))
        if not bad_locs:
            continue

        # Log some information in case of error.
        print("*" * 100)
        nb_errors = len(bad_locs)
        nb_vals = min(nb_errors, max_errors_printed)
        print(f"Detected {nb_errors} diverging values (output shape={ort_arr.shape})")
        print(f"Showing first {nb_vals} errors (ONNX -- TE):")
        abs_err = np.abs(ort_arr - te_arr)
        for loc in bad_locs[:nb_vals]:
            ref = te_arr[loc]
            threshold = atol + rtol * abs(ref)
            print(f"{ort_arr[loc]} -- {ref} err={abs_err[loc]} > {threshold}")
        print(f"Max error: {np.max(abs_err[bad_index])}")
        if nb_errors > allow_cnt_errors:
            raise ValueError(f"Output validation of {fname} failed with {nb_errors} errors")


def serialize_inputs_outputs(
    fname: str,
    inputs: Union[Tuple[torch.Tensor], torch.Tensor],
    te_outputs: List[torch.Tensor],
    input_names: Optional[List[str]] = None,
    output_names: Optional[List[str]] = None,
):
    """Write the test inputs and TE outputs next to the ONNX model (if SAVE_TEST_IO is set).

    Does nothing unless SAVE_TEST_IO is enabled. Otherwise two JSON files are
    written under NVTE_TEST_ARTIFACTS_DIR, named after `fname` with its last five
    characters (the ".onnx" suffix) replaced by "_inputs.json" and "_output.json".
    Inputs and outputs that are `None` are left out.
    """
    if not SAVE_TEST_IO:
        return

    # Strip the ".onnx" suffix once; both JSON files share this stem.
    stem = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)[: -len(".onnx")]

    input_names = input_names or ["input"]
    output_names = output_names or ["output"]
    if not isinstance(inputs, (list, tuple)):
        inputs = (inputs,)

    input_data = [
        {name: tensor.cpu() for name, tensor in zip(input_names, inputs) if tensor is not None}
    ]
    save_json(input_data, stem + "_inputs.json", description="custom input data")

    output_data = {
        name: tensor.detach().cpu()
        for name, tensor in zip(output_names, te_outputs)
        if tensor is not None
    }
    custom_outputs = RunResults()
    custom_outputs.add([output_data], runner_name="custom_runner")
    custom_outputs.save(stem + "_output.json")


def validate_result(
    fname: str,
    inps: Union[Tuple[torch.Tensor], torch.Tensor],
    model: torch.nn.Module,
    atol: float = 1.0e-8,  # np.isclose default atol
    rtol: float = 1.0e-5,  # np.isclose default rtol
    max_errors_printed: int = 10,
    is_fp8: bool = False,
    allow_cnt_errors: int = 0,
    input_names: List[str] = None,
    output_names: List[str] = None,
    te_outputs: List[torch.Tensor] = None,
    fp8_recipe: Optional[recipe.Recipe] = None,
):
    """Compare the outputs of a Transformer Engine (TE) module vs the outputs of its ONNX
    representation using ONNX Runtime (ORT) and ensure they are close.

    The purpose of the output comparison is to validate that TE models are converted to
    their correct ONNX representation by testing that TE and ORT outputs match within some
    small threshold (allowing for finite precision errors).

    Argument `allow_cnt_errors` reduces test failure noise due to spurious errors by ignoring,
    a very small number (0-3) of outliers. This is fine to do because these outliers are due to
    small kernel implementation differences between TE and ORT and do not imply an incorrect ONNX
    representation (the tests assume both ORT or TE kernels are correct).

    Argument `te_outputs` can be used to provide pre-computed TE outputs. When it is not
    provided (or is empty), the TE outputs are computed here by running `model` on `inps`
    with FP8 autocast controlled by `is_fp8` and `fp8_recipe`.

    Arguments `input_names` and `output_names` are accepted so call sites can pass the same
    names they gave to the export; they are not used by the validation itself, because the
    ORT input names are read from the loaded session.

    Raises `ValueError` (from `compare_outputs`) when more than `allow_cnt_errors` values of
    an output diverge.
    """

    def create_ort_session(fname: str, is_fp8: bool):
        """Create an ONNX Runtime session for validation."""

        def load_custom_ops(session_opts: ort.SessionOptions):
            """For FP8 validation with ORT we need to load our custom FP8 Q/DQ extension."""
            session_opts.register_custom_ops_library(get_library_path())
            print("registered custom FP8 Q/DQ ops!")

        kwargs = {"providers": ["CUDAExecutionProvider", "CPUExecutionProvider"]}
        if is_fp8:
            sess_options = ort.SessionOptions()
            load_custom_ops(sess_options)
            kwargs["sess_options"] = sess_options

        return ort.InferenceSession(fname, **kwargs)

    def create_ort_input_dict(session, inputs):
        """Pair the session's input names with the non-None inputs, converted to NumPy."""
        inputs = inputs if isinstance(inputs, (list, tuple)) else (inputs,)
        session_input_names = [x.name for x in session.get_inputs()]
        arrays = [to_numpy(x) for x in inputs if x is not None]
        return dict(zip(session_input_names, arrays))

    # Run ORT session and TE model.
    fname = os.path.join(NVTE_TEST_ARTIFACTS_DIR, fname)
    if not te_outputs:
        # te_infer requires the recipe as its fourth argument; omitting it raised a TypeError.
        te_outputs = te_infer(model, inps, is_fp8, fp8_recipe)
    ort_s = create_ort_session(fname, is_fp8)
    input_feed = create_ort_input_dict(ort_s, inps)
    onnx_outputs = ort_s.run(None, input_feed=input_feed)
    compare_outputs(
        onnx_outputs, te_outputs, atol, rtol, max_errors_printed, allow_cnt_errors, fname
    )


def dtype2str(dtype: torch.dtype, fake_bf16_io=False):
    """Return the file-name suffix used for `dtype`.

    With `fake_bf16_io` set, `dtype` must be bfloat16 and the suffix is
    "_fake_bf16". Otherwise float32, float16 and bfloat16 map to "_fp32", "_fp16"
    and "_bf16"; any other dtype raises `KeyError`.
    """
    if not fake_bf16_io:
        suffix_by_dtype = {
            torch.float32: "_fp32",
            torch.float16: "_fp16",
            torch.bfloat16: "_bf16",
        }
        return suffix_by_dtype[dtype]

    assert dtype == torch.bfloat16
    return "_fake_bf16"


def as_te_type(dtype: torch.dtype):
    """Return the `tex.DType` member matching a torch dtype.

    Supports float32, float16 and bfloat16; any other dtype raises `KeyError`.
    """
    te_dtype_by_torch_dtype = {
        torch.float32: tex.DType.kFloat32,
        torch.float16: tex.DType.kFloat16,
        torch.bfloat16: tex.DType.kBFloat16,
    }
    return te_dtype_by_torch_dtype[dtype]


def get_attn_mask_str(use_mask, attn_mask_type):
    """Return the file-name fragment describing the attention-mask configuration.

    Without a mask type the fragment only says whether a mask is used. With a mask
    type, "causal" always yields "_causal-mask"; "arbitrary" yields
    "_arbitrary-mask" when a mask is used; everything else is "_arbitrary-no-mask".
    """
    # See FusedScaleMaskSoftmax::forward_fused_softmax for logic behind names.
    if attn_mask_type is None:
        if use_mask:
            return "_mask"
        return "_no-mask"
    if attn_mask_type == "causal":
        return "_causal-mask"
    if use_mask and attn_mask_type == "arbitrary":
        return "_arbitrary-mask"
    return "_arbitrary-no-mask"


"""
Test cases begin here.
"""


def _test_export_linear(
    fp8_recipe: recipe.Recipe = fp8_recipes[0],
    use_bias: bool = True,
    return_bias: bool = False,
    precision: torch.dtype = torch.float32,
):
    """Export a `te.Linear` layer to ONNX and validate the ORT output against TE.

    The export uses a dynamic first (batch) dimension. Validation against ORT is
    skipped for bfloat16; the export and the TE forward pass still run. The test is
    skipped when `return_bias` is requested without `use_bias`.
    """
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")

    use_fp8 = fp8_recipe is not None

    # Set dimensions (these are arbitrary).
    batch_size, hidden_size = 4, 64
    in_features = out_features = 64

    class Test_Linear(nn.Module):
        """Wrapper whose forward takes a single argument named `inp`."""

        def __init__(self, in_features, out_features, use_bias, return_bias, precision):
            super().__init__()
            self.linear = te.Linear(
                in_features,
                out_features,
                bias=use_bias,
                return_bias=return_bias,
                params_dtype=precision,
            )

        def forward(self, inp):
            return self.linear(inp)

    # The input is drawn before the model is built, keeping the RNG consumption order fixed.
    inp = torch.randn(batch_size, hidden_size, in_features, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    fname = f"te.linear{fp8_str}{bias_str}{dtype2str(precision)}.onnx"

    with te.autocast(enabled=use_fp8, recipe=fp8_recipe):
        model = Test_Linear(in_features, out_features, use_bias, return_bias, precision)
        model = model.to(device="cuda")

        # Dynamic batch dimension, keyed by the forward() argument name.
        bs = torch.export.Dim("bs", min=2, max=1256)
        do_export(model, inp, fname, fp8_recipe, dynamic_shapes={"inp": {0: bs}})

        te_outputs = te_infer(model, inp, is_fp8=use_fp8, fp8_recipe=fp8_recipe)
        serialize_inputs_outputs(fname, inp, te_outputs)

        if precision == torch.bfloat16:
            return

        # FP8 runs get a looser absolute tolerance than high-precision runs.
        validate_result(
            fname,
            inp,
            model,
            atol=1e-2 if use_fp8 else 1e-3,
            is_fp8=use_fp8,
            te_outputs=te_outputs,
        )


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_linear_recipe(seed_default_rng, fp8_recipe, precision):
    """Export te.Linear for every entry of `fp8_recipes` crossed with fp32/fp16/bf16."""
    _test_export_linear(fp8_recipe=fp8_recipe, precision=precision)


@pytest.mark.parametrize("use_bias", [True, False])
def test_export_linear_use_bias(seed_default_rng, use_bias):
    """Export te.Linear with and without a bias, using the helper's default recipe."""
    _test_export_linear(use_bias=use_bias)


@pytest.mark.parametrize("return_bias", [True, False])
def test_export_linear_return_bias(seed_default_rng, return_bias):
    """Export te.Linear with `return_bias` on and off, using the helper's default recipe."""
    _test_export_linear(return_bias=return_bias)


def _test_export_layernorm(
    fp8_recipe: recipe.Recipe = fp8_recipes[0],
    precision: torch.dtype = torch.float32,
    zero_centered_gamma: bool = False,
    normalization: str = all_normalizations[0],
):
    """Export a `te.LayerNorm` / `te.RMSNorm` layer to ONNX and validate it against TE.

    `normalization == "LayerNorm"` selects `te.LayerNorm`; any other value selects
    `te.RMSNorm`. The export uses a dynamic first (batch) dimension. Validation
    against ORT is skipped for bfloat16; the export and the TE forward pass still run.
    """
    # Set dimensions (these are arbitrary).
    batch_size = 4
    in_features = 64
    hidden_size = 256

    use_fp8 = fp8_recipe is not None
    inp = torch.ones(batch_size, in_features, hidden_size, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    high_prec_str = dtype2str(precision)
    # Named "te.layernorm", not "te.layernorm_linear": the latter is the name produced by
    # the LayerNormLinear test without bias, so the two artifacts used to overwrite each other.
    fname = f"te.layernorm{fp8_str}{high_prec_str}.onnx"

    with torch.no_grad():
        with te.autocast(enabled=use_fp8, recipe=fp8_recipe):
            layernorm_cls = te.LayerNorm if normalization == "LayerNorm" else te.RMSNorm
            model = layernorm_cls(
                hidden_size,
                params_dtype=precision,
                zero_centered_gamma=zero_centered_gamma,
            ).to(device="cuda")

            # dynamic shape
            bs = torch.export.Dim("bs", min=2, max=1256)
            do_export(model, inp, fname, fp8_recipe, dynamic_shapes={"input": {0: bs}})
            te_outputs = te_infer(model, inp, is_fp8=use_fp8, fp8_recipe=fp8_recipe)
            serialize_inputs_outputs(fname, inp, te_outputs)
            if precision in (torch.bfloat16,):
                return
            # bfloat16 has already returned above, so one call covers both the FP8 and the
            # high-precision case; they only differ in `is_fp8`.
            validate_result(
                fname,
                inp,
                model,
                atol=1e-3,
                is_fp8=use_fp8,
                te_outputs=te_outputs,
            )


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_recipe(seed_default_rng, fp8_recipe, precision):
    """Export the normalization layer for every `fp8_recipes` entry and fp32/fp16/bf16."""
    _test_export_layernorm(fp8_recipe=fp8_recipe, precision=precision)


def test_export_layernorm_zero_centered_gamma(seed_default_rng):
    """Export the normalization layer with `zero_centered_gamma=True`."""
    _test_export_layernorm(zero_centered_gamma=True)


@pytest.mark.parametrize("normalization", all_normalizations)
def test_export_layernorm_normalization(seed_default_rng, normalization):
    """Export the normalization layer for every entry of `all_normalizations`."""
    _test_export_layernorm(normalization=normalization)


def _test_export_layernorm_linear(
    scale_factor: float = 112,
    fp8_recipe: recipe.Recipe = fp8_recipes[0],
    use_bias: bool = True,
    return_bias: bool = False,
    return_layernorm_output: bool = False,
    precision: torch.dtype = torch.float32,
    zero_centered_gamma: bool = False,
    normalization: str = all_normalizations[0],
):
    """Export a `te.LayerNormLinear` layer to ONNX and validate the ORT output against TE.

    When an FP8 recipe is given, the forward quantizer scales are set to
    `scale_factor` before the export. Validation against ORT is skipped for
    bfloat16. The test is skipped when `return_bias` is requested without `use_bias`.
    """
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")

    use_fp8 = fp8_recipe is not None

    # Set dimensions (these are arbitrary).
    num_rows, hidden_size = 64, 256

    inp = torch.randn(num_rows, hidden_size, device="cuda", dtype=precision)
    name_suffix = "".join(
        (
            "_fp8" if use_fp8 else "",
            "_bias" if use_bias else "",
            dtype2str(precision),
        )
    )
    fname = f"te.layernorm_linear{name_suffix}.onnx"

    with torch.no_grad(), te.autocast(enabled=use_fp8, recipe=fp8_recipe):
        model = te.LayerNormLinear(
            hidden_size,
            3 * hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
            normalization=normalization,
        ).to(device="cuda")
        if use_fp8:
            set_layer_scale(model, scale_factor, num_gemms=2)
        do_export(model, inp, fname, fp8_recipe)

        te_outputs = te_infer(model, inp, is_fp8=use_fp8, fp8_recipe=fp8_recipe)
        serialize_inputs_outputs(fname, inp, te_outputs)
        if precision == torch.bfloat16:
            return

        # For current scaling we use Float8Quantizer in tests + amax computed by hand,
        # which has slightly different numerics than Float8CurrentScalingQuantizer.
        if fp8_recipe.__class__ is recipe.Float8CurrentScaling:
            atol = 2e-2
        else:
            atol = 1e-3
        validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs)


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_linear_recipe(seed_default_rng, fp8_recipe, precision):
    """Export te.LayerNormLinear for every `fp8_recipes` entry and fp32/fp16/bf16."""
    _test_export_layernorm_linear(fp8_recipe=fp8_recipe, precision=precision)


def test_export_layernorm_linear_return_ln_out(seed_default_rng):
    """Export te.LayerNormLinear with `return_layernorm_output=True`."""
    _test_export_layernorm_linear(return_layernorm_output=True)


def test_export_layernorm_linear_zero_centered_gamma(seed_default_rng):
    """Export te.LayerNormLinear with `zero_centered_gamma=True`."""
    _test_export_layernorm_linear(zero_centered_gamma=True)


@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_linear_normalization(seed_default_rng, normalization):
    """Export te.LayerNormLinear for every normalization except the helper's default one."""
    _test_export_layernorm_linear(normalization=normalization)


def test_export_layernorm_linear_no_bias(seed_default_rng):
    """Export te.LayerNormLinear with the bias disabled."""
    _test_export_layernorm_linear(use_bias=False)


def test_export_layernorm_linear_return_bias(seed_default_rng):
    """Export te.LayerNormLinear with `return_bias=True`."""
    _test_export_layernorm_linear(return_bias=True)


def _test_export_layernorm_mlp(
    scale_factor: float = 112,
    fp8_recipe: recipe.Recipe = fp8_recipes[0],
    use_bias: bool = True,
    return_bias: bool = False,
    return_layernorm_output: bool = False,
    precision: torch.dtype = torch.float32,
    zero_centered_gamma: bool = False,
    activation: str = supported_activations[0],
    normalization: str = all_normalizations[0],
):
    """Export a `te.LayerNormMLP` layer to ONNX and validate the ORT output against TE.

    When an FP8 recipe is given, the forward quantizer scales are set to
    `scale_factor` before the export. Validation against ORT is skipped for
    bfloat16. The test is skipped when `return_bias` is requested without `use_bias`.
    """
    if return_bias and not use_bias:
        pytest.skip("Cannot return bias when bias is disabled")

    use_fp8 = fp8_recipe is not None

    # Set dimensions (these are arbitrary).
    num_rows = 64
    hidden_size = ffn_hidden_size = 256

    inp = torch.randn(num_rows, hidden_size, device="cuda", dtype=precision)
    fp8_str = "_fp8" if use_fp8 else ""
    bias_str = "_bias" if use_bias else ""
    fname = f"te.layernorm_mlp{fp8_str}{bias_str}{dtype2str(precision)}_{activation}.onnx"

    with te.autocast(enabled=use_fp8, recipe=fp8_recipe):
        model = te.LayerNormMLP(
            hidden_size,
            ffn_hidden_size,
            bias=use_bias,
            return_bias=return_bias,
            return_layernorm_output=return_layernorm_output,
            params_dtype=precision,
            zero_centered_gamma=zero_centered_gamma,
            activation=activation,
            normalization=normalization,
        ).to(device="cuda")
        if use_fp8:
            set_layer_scale(model, scale_factor, num_gemms=2)
        do_export(model, inp, fname, fp8_recipe)
        te_outputs = te_infer(model, inp, is_fp8=use_fp8, fp8_recipe=fp8_recipe)
        serialize_inputs_outputs(fname, inp, te_outputs)
        if precision == torch.bfloat16:
            return

        # Pick the absolute tolerance: FP8 first, then the swiglu special case.
        if use_fp8:
            atol = 2e-2  # TODO(pgadzinski) - check 2e-2
        elif activation == "swiglu":
            atol = 5e-1
        else:
            atol = 1e-3
        validate_result(fname, inp, model, atol=atol, is_fp8=use_fp8, te_outputs=te_outputs)


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_layernorm_mlp(seed_default_rng, fp8_recipe, precision):
    """Export te.LayerNormMLP for every `fp8_recipes` entry and fp32/fp16/bf16."""
    _test_export_layernorm_mlp(fp8_recipe=fp8_recipe, precision=precision)


def test_export_layernorm_mlp_return_layernorm_output(seed_default_rng):
    """Export te.LayerNormMLP with `return_layernorm_output=True`."""
    _test_export_layernorm_mlp(return_layernorm_output=True)


def test_export_layernorm_mlp_return_bias(seed_default_rng):
    """Export te.LayerNormMLP with `return_bias=True`."""
    _test_export_layernorm_mlp(return_bias=True)


def test_export_layernorm_mlp_no_bias(seed_default_rng):
    """Export te.LayerNormMLP with the bias disabled."""
    _test_export_layernorm_mlp(use_bias=False)


def test_export_layernorm_mlp_zero_centered_gamma(seed_default_rng):
    """Export te.LayerNormMLP with `zero_centered_gamma=True`."""
    _test_export_layernorm_mlp(zero_centered_gamma=True)


@pytest.mark.parametrize("normalization", all_normalizations[1:])
def test_export_layernorm_mlp_normalization(seed_default_rng, normalization):
    """Export te.LayerNormMLP for every normalization except the helper's default one."""
    _test_export_layernorm_mlp(normalization=normalization)


@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_layernorm_mlp_activation(seed_default_rng, activation):
    """Export te.LayerNormMLP for every activation except the helper's default one."""
    _test_export_layernorm_mlp(activation=activation)


@pytest.mark.parametrize(
    "precision,      use_mask, attn_mask_type",
    [
        (torch.float32, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.float32, False, "no_mask"),  # calls forward_torch_softmax (apply no mask)
        (torch.float16, False, "causal"),  # calls forward_torch_softmax (apply dynamic onnx mask)
        (torch.float16, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.float16, False, "no_mask"),  # calls forward_torch_softmax (apply no mask)
        (torch.bfloat16, False, "causal"),  # calls forward_torch_softmax (apply dynamic onnx mask)
        (torch.bfloat16, True, "arbitrary"),  # calls forward_torch_softmax (apply user mask)
        (torch.bfloat16, False, "no_mask"),  # calls forward_torch_softmax (apply no mask)
    ],
)
def test_export_core_attention(
    precision: torch.dtype,
    use_mask: bool,
    attn_mask_type: str,
):
    """Export `DotProductAttention` to ONNX and validate the ORT output against TE.

    Query, key and value are random tensors in "sbhd" layout; a random boolean
    mask is added when `use_mask` is set. Validation against ORT is skipped for
    bfloat16.
    """
    # Set dimensions (these are arbitrary).
    qkv_format = "sbhd"
    seq_len, batch_size, num_attention_heads, kv_channels = (64, 4, 1, 64)
    qkv_shape = (seq_len, batch_size, num_attention_heads, kv_channels)

    # Drawn strictly in the order query, key, value.
    query_layer, key_layer, value_layer = (
        torch.randn(qkv_shape, dtype=precision, device="cuda") for _ in range(3)
    )

    attention_mask = None
    if use_mask:
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(batch_size, 1, 1, seq_len, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)

    input_names = ["query", "key", "value", "attention_mask"]
    inp = (query_layer, key_layer, value_layer, attention_mask)

    mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.core_attention{mask_str}{dtype2str(precision)}.onnx"

    model = te.attention.DotProductAttention(
        num_attention_heads=num_attention_heads,
        kv_channels=kv_channels,
        attention_dropout=0.5,
        qkv_format=qkv_format,
        attn_mask_type=attn_mask_type,
    ).to(device="cuda")

    do_export(model, inp, fname, input_names=input_names, fp8_recipe=None)
    te_outputs = te_infer(model, inp, is_fp8=False, fp8_recipe=None)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)

    if precision == torch.bfloat16:
        return

    # TE outputs are precomputed above, so `is_fp8=True` only makes validate_result
    # register the custom-op library with the ORT session.
    validate_result(
        fname,
        inp,
        model,
        is_fp8=True,
        atol=1e-2,
        input_names=input_names,
        te_outputs=te_outputs,
    )


# Attention test configurations: all eight combinations of input layernorm on/off,
# self/cross attention, and fused/unfused QKV parameters.
test_configs_attention_type = [
    # "input_layernorm, attention_type, fuse_qkv_params"
    (True, "self", True),
    (False, "self", True),
    (True, "self", False),
    (False, "self", False),
    (True, "cross", True),
    (False, "cross", True),
    (True, "cross", False),
    (False, "cross", False),
]


def _test_export_multihead_attention(
    fp8_recipe: recipe.Recipe = fp8_recipes[0],
    use_mask: bool = True,
    precision: torch.dtype = torch.float32,
    input_layernorm: bool = True,
    attention_type: str = "self",
    fuse_qkv_params: bool = True,
):
    """Export a `te.MultiheadAttention` layer to ONNX and validate the ORT output against TE.

    The mask type is "arbitrary" when `use_mask` is set and "no_mask" otherwise.
    Cross attention additionally feeds a random encoder output. Sequence and batch
    dimensions are exported as dynamic. Validation against ORT is skipped for
    bfloat16.
    """
    use_fp8 = fp8_recipe is not None

    # Layer configuration (sizes are arbitrary).
    hidden_size = 256
    sequence_length = 128
    batch_size = 4
    num_attention_heads = 32
    kv_channels = 8
    attention_dropout = 0.1
    layernorm_epsilon = 1e-5
    init_method = output_layer_init_method = get_default_init_method()
    attn_mask_type = "arbitrary" if use_mask else "no_mask"

    def random_hidden_states(seq_len):
        """Draw a random (seq_len, batch, hidden) activation tensor on the GPU."""
        return torch.randn(seq_len, batch_size, hidden_size, dtype=precision, device="cuda")

    # Random draws happen in this order: hidden states, mask, encoder output.
    hidden_states_context = random_hidden_states(sequence_length)

    attention_mask = None
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        probs = 0.5 * torch.ones(
            batch_size, 1, sequence_length, sequence_length, device="cuda", dtype=precision
        )
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)

    encoder_output = None
    if attention_type == "cross":
        encoder_output = random_hidden_states(sequence_length)

    # The file name encodes the configuration, one fragment per option.
    name_fragments = (
        "_fp8" if use_fp8 else "",
        get_attn_mask_str(use_mask, attn_mask_type),
        "_self-attention" if attention_type == "self" else "_cross-attention",
        "_input-ln" if input_layernorm else "",
        "_fused-qkv" if fuse_qkv_params else "",
        dtype2str(precision),
    )
    fname = "te.multihead_attention" + "".join(name_fragments) + ".onnx"

    model = te.MultiheadAttention(
        hidden_size,
        num_attention_heads,
        kv_channels,
        attention_dropout,
        layernorm_epsilon,
        init_method,
        output_layer_init_method,
        attn_mask_type=attn_mask_type,
        params_dtype=precision,
        return_layernorm_output=False,
        input_layernorm=input_layernorm,
        attention_type=attention_type,
        fuse_qkv_params=fuse_qkv_params,
        return_bias=True,
    ).to(device="cuda")

    inp_context = (hidden_states_context, attention_mask, encoder_output)
    input_names = ["hidden_states", "attention_mask", "encoder_output"]
    output_names = ["attention_output", "attention_bias"]

    seq = torch.export.Dim("seq", min=2, max=1256)
    bs = torch.export.Dim("bs", min=2, max=1256)
    dynamic_shapes = {
        "hidden_states": {0: seq, 1: bs},
        "attention_mask": {2: seq, 0: bs} if use_mask else None,
        "encoder_output": {0: seq, 1: bs} if attention_type == "cross" else None,
    }
    do_export(
        model,
        inp_context,
        fname,
        fp8_recipe,
        input_names=input_names,
        output_names=output_names,
        dynamic_shapes=dynamic_shapes,
    )
    te_outputs = te_infer(model, inp_context, is_fp8=use_fp8, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(
        fname, inp_context, te_outputs, input_names=input_names, output_names=output_names
    )
    if precision == torch.bfloat16:
        return

    # FP8 runs use a looser tolerance and may ignore up to three outliers.
    if use_fp8:
        check_kwargs = {"atol": 1e-2, "is_fp8": True, "allow_cnt_errors": 3}
    else:
        check_kwargs = {"atol": 1e-3}
    check_kwargs["input_names"] = input_names
    check_kwargs["output_names"] = output_names

    validate_result(fname, inp_context, model, te_outputs=te_outputs, **check_kwargs)

    # In GPT generative phase (inference) the input sequence is smaller than the maximum
    # allowed sequence length and we want to test this condition.
    # Pretend that we're in generative phase when it makes sense (causal mask and self-attention).
    # The mask type chosen above is only ever "arbitrary" or "no_mask", so this branch is
    # not taken with the current configuration.
    is_generative_phase = attn_mask_type == "causal" and attention_type == "self"
    if not is_generative_phase:
        return

    seq_len_offset = 8
    hidden_states_generative = random_hidden_states(sequence_length - seq_len_offset)
    inp_generative = (hidden_states_generative, attention_mask, encoder_output)
    validate_result(fname, inp_generative, model, **check_kwargs)


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_multihead_attention_recipe(fp8_recipe, precision):
    """Run the multi-head attention export test for each recipe/precision combination.

    All other options of ``_test_export_multihead_attention`` keep their defaults.
    """
    case = {"fp8_recipe": fp8_recipe, "precision": precision}
    _test_export_multihead_attention(**case)


def test_export_multihead_attention_no_mask():
    """Run the multi-head attention export test with ``use_mask=False``."""
    overrides = {"use_mask": False}
    _test_export_multihead_attention(**overrides)


def test_export_multihead_attention_no_input_layernorm():
    """Run the multi-head attention export test with ``input_layernorm=False``."""
    overrides = {"input_layernorm": False}
    _test_export_multihead_attention(**overrides)


def test_export_multihead_attention_cross_attn():
    """Run the multi-head attention export test with ``attention_type="cross"``."""
    overrides = {"attention_type": "cross"}
    _test_export_multihead_attention(**overrides)


def test_export_multihead_attention_unfused_qkv_params():
    """Run the multi-head attention export test with ``fuse_qkv_params=False``."""
    overrides = {"fuse_qkv_params": False}
    _test_export_multihead_attention(**overrides)


def _test_export_transformer_layer(
    fp8_recipe: recipe.Recipe = fp8_recipes[0],
    use_mask: bool = True,
    attn_mask_type: str = "arbitrary",
    output_layernorm: bool = False,
    precision: torch.dtype = torch.float32,
    fuse_qkv_params: bool = True,
    zero_centered_gamma: bool = False,
    activation: str = supported_activations[0],
):
    """Export a ``te.TransformerLayer`` to ONNX and compare it with TE's own output.

    The layer is built, exported with ``do_export``, run through ``te_infer``, and the
    inputs/outputs are handed to ``serialize_inputs_outputs``. Unless ``precision`` is
    ``torch.bfloat16`` (in which case the function returns right after serialization),
    the exported model is then checked with ``validate_result``.

    Args:
        fp8_recipe: Recipe forwarded to ``do_export`` and ``te_infer``; ``None`` means
            the FP8 flags passed to the helpers are ``False``.
        use_mask: When true and ``attn_mask_type`` is not ``"causal"``, a random boolean
            attention mask is generated; otherwise the mask input is ``None``.
        attn_mask_type: Passed to the layer as ``self_attn_mask_type``.
        output_layernorm: Passed through to the layer.
        precision: dtype of the input tensor and of the layer parameters.
        fuse_qkv_params: Passed through to the layer.
        zero_centered_gamma: Passed through to the layer.
        activation: Passed through to the layer; also part of the ONNX file name.
    """
    uses_fp8 = fp8_recipe is not None

    # Layer configuration
    hidden_size, ffn_hidden_size, num_attention_heads = 64, 256, 4
    sequence_length, batch_size = 128, 1

    input_names = ["input", "attention_mask"]
    input_tensor = torch.rand(
        sequence_length, batch_size, hidden_size, dtype=precision, device="cuda"
    )
    if use_mask and attn_mask_type != "causal":
        # Generate a random mask with 50% probability for 0 or 1.
        mask_shape = (batch_size, 1, sequence_length, sequence_length)
        probs = 0.5 * torch.ones(*mask_shape, device="cuda", dtype=precision)
        attention_mask = torch.bernoulli(probs).to("cuda", dtype=torch.bool)
    else:
        attention_mask = None
    inp = (input_tensor, attention_mask)

    # Pieces of the output file name, one per configuration knob that is encoded in it.
    fp8_str = "_fp8" if uses_fp8 else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
    high_prec_str = dtype2str(precision)
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.transformer_layer{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}_{activation}.onnx"

    layer_options = dict(
        self_attn_mask_type=attn_mask_type,
        output_layernorm=output_layernorm,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
        zero_centered_gamma=zero_centered_gamma,
        activation=activation,
    )
    model = te.TransformerLayer(
        hidden_size, ffn_hidden_size, num_attention_heads, **layer_options
    ).to(device="cuda")

    do_export(model, inp, fname, fp8_recipe, input_names=input_names)
    te_outputs = te_infer(model, inp, is_fp8=uses_fp8, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(fname, inp, te_outputs, input_names=input_names)

    # bfloat16 runs stop after export and serialization; no validation is done for them.
    if precision == torch.bfloat16:
        return

    # FP8 and the "swiglu" activation use the loose tolerance; everything else the tight one.
    atol = 5e-1 if (uses_fp8 or activation == "swiglu") else 5e-3
    validate_result(
        fname,
        inp,
        model,
        atol=atol,
        is_fp8=uses_fp8,
        input_names=input_names,
        te_outputs=te_outputs,
    )


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float32, torch.float16, torch.bfloat16])
def test_export_transformer_layer_recipe(fp8_recipe, precision):
    """Run the transformer-layer export test for each recipe/precision combination.

    All other options of ``_test_export_transformer_layer`` keep their defaults.
    """
    case = {"fp8_recipe": fp8_recipe, "precision": precision}
    _test_export_transformer_layer(**case)


def test_export_transformer_layer_no_mask():
    """Run the transformer-layer export test with ``use_mask=False``."""
    overrides = {"use_mask": False}
    _test_export_transformer_layer(**overrides)


def test_export_transformer_layer_output_layernorm():
    """Run the transformer-layer export test with ``output_layernorm=True``."""
    overrides = {"output_layernorm": True}
    _test_export_transformer_layer(**overrides)


def test_export_transformer_layer_unfused_qkv_params():
    """Run the transformer-layer export test with ``fuse_qkv_params=False``."""
    overrides = {"fuse_qkv_params": False}
    _test_export_transformer_layer(**overrides)


def test_export_transformer_layer_zero_centered_gamma():
    """Run the transformer-layer export test with ``zero_centered_gamma=True``."""
    overrides = {"zero_centered_gamma": True}
    _test_export_transformer_layer(**overrides)


@pytest.mark.parametrize("activation", supported_activations[1:])
def test_export_transformer_layer_activation(activation):
    """Run the transformer-layer export test for every supported activation but the first.

    The first entry of ``supported_activations`` is skipped by the parametrization.
    """
    overrides = {"activation": activation}
    _test_export_transformer_layer(**overrides)


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
@pytest.mark.parametrize("precision", [torch.float16, torch.bfloat16])
def test_export_gpt_generation(
    fp8_recipe: recipe.Recipe,
    precision: torch.dtype,
):
    """Test that the ONNX model can correctly handle inputs with different shapes and that
    the attention mask is adjusted on-the-fly to different sequence lengths.

    A causal ``te.TransformerLayer`` is exported once with dynamic sequence and batch
    dimensions. The exported file is then exercised twice: with the full-length input
    used for the export ("context phase") and with a shorter input ("generative phase").
    For ``torch.bfloat16`` the ``validate_result`` step is skipped in both phases.
    """
    uses_fp8 = fp8_recipe is not None
    skip_validation = precision in (torch.bfloat16,)

    # Layer configuration
    hidden_size, ffn_hidden_size, num_attention_heads = 64, 256, 4
    context_seq_len, batch_size = 128, 4
    use_mask = True
    attn_mask_type = "causal"
    fuse_qkv_params = True

    fp8_str = "_fp8" if uses_fp8 else ""
    fuse_qkv_params_str = "_fused-qkv" if fuse_qkv_params else ""
    high_prec_str = dtype2str(precision)
    attn_mask_str = get_attn_mask_str(use_mask, attn_mask_type)
    fname = f"te.transformer_layer_generative{fp8_str}{attn_mask_str}{fuse_qkv_params_str}{high_prec_str}.onnx"

    model = te.TransformerLayer(
        hidden_size,
        ffn_hidden_size,
        num_attention_heads,
        self_attn_mask_type=attn_mask_type,
        output_layernorm=False,
        params_dtype=precision,
        fuse_qkv_params=fuse_qkv_params,
    ).to(device="cuda")

    input_names = ["input"]
    output_names = ["output"]

    def random_hidden_states(seq_len):
        # Random (seq_len, batch_size, hidden_size) input in the tested precision.
        return torch.rand(seq_len, batch_size, hidden_size, dtype=precision, device="cuda")

    def check_exported_model(inputs, reference):
        # Compare the exported file against TE's reference outputs, except for bfloat16.
        if skip_validation:
            return
        validate_result(
            fname,
            inputs,
            model,
            atol=1e-2,
            is_fp8=uses_fp8,
            input_names=input_names,
            te_outputs=reference,
        )

    # "Context phase": use full input sequence length
    context_inp = (random_hidden_states(context_seq_len),)
    # dynamic shape
    seq = torch.export.Dim("seq", min=2, max=1256)
    bs = torch.export.Dim("bs", min=2, max=1256)
    do_export(
        model,
        context_inp,
        fname,
        fp8_recipe,
        dynamic_shapes={"hidden_states": {0: seq, 1: bs}},
    )
    context_out = te_infer(model, context_inp, is_fp8=uses_fp8, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(
        fname, context_inp, context_out, input_names=input_names, output_names=output_names
    )
    check_exported_model(context_inp, context_out)

    # "Generative phase": use a single input (sequence len=1). For FP8 we need to pad the
    # sequence to a multiple of 8 and for MXFP8 to a multiple of 32, so 32 is used there.
    generative_seq_len = 32 if uses_fp8 else 1
    # The second tuple element is the attention mask, which is not provided here.
    generative_inp = (random_hidden_states(generative_seq_len), None)
    generative_out = te_infer(model, generative_inp, is_fp8=uses_fp8, fp8_recipe=fp8_recipe)
    serialize_inputs_outputs(fname, generative_inp, generative_out, input_names=input_names)
    check_exported_model(generative_inp, generative_out)


@pytest.mark.parametrize("enabled", [True, False])
def test_export_ctx_manager(enabled):
    """Check that ``te.onnx_export`` changes the export-mode flag only inside its body.

    ``is_in_onnx_export_mode()`` must compare equal to ``False`` before and after the
    ``with`` block, and equal to ``enabled`` while inside it.
    """
    mode_before = is_in_onnx_export_mode()
    assert mode_before == False

    with te.onnx_export(enabled):
        mode_inside = is_in_onnx_export_mode()
        assert mode_inside == enabled

    mode_after = is_in_onnx_export_mode()
    assert mode_after == False


@pytest.mark.parametrize("fp8_recipe", fp8_recipes)
def test_trt_integration(fp8_recipe: recipe.Recipe):
    """Export a TE model to ONNX, build a TensorRT engine from it and compare outputs.

    The model is run once under ``te.autocast`` to obtain a reference output, exported
    with ``torch.onnx.export`` (dynamo) inside ``te.onnx_export``, compiled with the
    ``trtexec`` command-line tool, and the resulting engine is executed through the
    TensorRT Python runtime. The engine output must match the reference within the
    tolerances chosen below.

    Args:
        fp8_recipe: Recipe used for autocast during the reference run and the export;
            ``None`` disables autocast.

    Raises:
        subprocess.CalledProcessError: If ``trtexec`` exits with a non-zero status.
        FileNotFoundError: If the ``trtexec`` executable cannot be found.
    """
    # Imported locally so this test does not depend on the module-level import list.
    import subprocess

    model = te.TransformerLayer(
        hidden_size=128,
        ffn_hidden_size=128,
        num_attention_heads=4,
    ).eval()

    if isinstance(fp8_recipe, recipe.Float8CurrentScaling):
        # TODO(pgadzinski): Attention does not work with TRT for FP8CurrentScaling
        model = te.LayerNormMLP(128, 128)

    inps = (torch.randn([16, 16, 128], device="cuda", requires_grad=False),)

    with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
        out_ref = model(*inps)

    onnx_fd, onnx_path = tempfile.mkstemp(suffix=".onnx")
    os.close(onnx_fd)
    engine_path = onnx_path + ".engine"
    try:
        with te.autocast(enabled=fp8_recipe is not None, recipe=fp8_recipe):
            with te.onnx_export(enabled=True):
                torch.onnx.export(
                    model,
                    inps,
                    onnx_path,
                    output_names=["output"],
                    dynamo=True,
                    custom_translation_table=te_translation_table,
                )

        # Build the engine without going through a shell, and fail right here if the
        # build fails instead of later on a missing engine file.
        subprocess.run(
            ["trtexec", f"--onnx={onnx_path}", f"--saveEngine={engine_path}"],
            check=True,
        )

        # Run TRT engine
        logger = trt.Logger(trt.Logger.WARNING)
        runtime = trt.Runtime(logger)
        with open(engine_path, "rb") as f:
            engine_data = f.read()
        engine = runtime.deserialize_cuda_engine(engine_data)
        assert engine is not None, "TensorRT failed to deserialize the engine"
        context = engine.create_execution_context()
        context.set_tensor_address(engine.get_tensor_name(0), inps[0].data_ptr())

        out = torch.zeros_like(out_ref)
        context.set_tensor_address("output", out.data_ptr())

        stream = torch.cuda.Stream()
        # Order the TRT stream after the work already queued on the current stream
        # (e.g. the zero-fill of `out`) so the engine does not race with it.
        stream.wait_stream(torch.cuda.current_stream())
        executed = context.execute_async_v3(stream_handle=stream.cuda_stream)
        assert executed, "TensorRT failed to enqueue the engine execution"
        stream.synchronize()

        # Compare TRT and TE outputs
        atol = 5e-2 if fp8_recipe is not None else 1e-4
        rtol = 5e-2 if fp8_recipe is not None else 1e-4
        torch.testing.assert_close(out, out_ref, atol=atol, rtol=rtol)
    finally:
        # Remove both the exported model and the engine built from it.
        for path in (onnx_path, engine_path):
            try:
                os.remove(path)
            except FileNotFoundError:
                pass
